import sys
import os
from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import torch as t
import argparse
from cut.utils import load_model
import json

def prepare_data(data, batch_size=8):
    """
    Return a generator of batches of the form [(prompt_text, answer), ...].

    Args:
        data: iterable of rows; each row must provide ``row['question']``,
            ``row['choices']`` (indexed 0..3) and ``row['answer']``.
        batch_size: maximum number of (prompt, answer) pairs per batch.

    Yields:
        Lists of ``(prompt, answer)`` tuples. Every batch holds exactly
        ``batch_size`` items except possibly the last one, which holds the
        remaining rows when the number of rows is not a multiple of
        ``batch_size``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (raised when the
            generator is first advanced).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batch = []
    for row in data:
        choices = row['choices']
        # The prompt deliberately ends with "Answer:" followed by a newline,
        # matching the template used by the rest of the evaluation.
        question = (
            "The following are multiple choice questions (with answers).\n"
            "\n"
            f"{row['question']}\n"
            f"A. {choices[0]}\n"
            f"B. {choices[1]}\n"
            f"C. {choices[2]}\n"
            f"D. {choices[3]}\n"
            "Answer:\n"
        )
        batch.append((question, row['answer']))
        if len(batch) == batch_size:
            yield batch
            batch = []

    # Flush the remainder so trailing rows are not silently dropped.
    if batch:
        yield batch

def get_accuracy(model, tokenizer, batches):
    """
    Compute multiple-choice accuracy of ``model`` over ``batches``.

    For every prompt, the logits at the last real (non-padding) token are
    restricted to the token ids of "A", "B", "C" and "D"; the argmax over
    those four is the predicted choice index, which is compared with the
    answer stored in the batch.

    Args:
        model: callable as ``model(**inputs)`` returning an object with a
            ``.logits`` tensor, and exposing ``.device``.
        tokenizer: provides ``encode(str)`` and is callable on a list of
            strings with ``return_tensors="pt", padding=True``.
        batches: iterable of batches; each batch is a sequence of
            ``(prompt_text, answer)`` pairs. Answers are compared against an
            argmax over the four choices, so they are assumed to be integer
            indices 0..3 (A..D) -- confirm against the dataset.

    Returns:
        Fraction of correctly answered prompts as a float.

    Raises:
        ValueError: if ``batches`` yields no examples.
    """
    # get token idxs for A, B, C, D (last id, in case encode() prepends
    # special tokens)
    choice_idxs = t.tensor(
        [tokenizer.encode(letter)[-1] for letter in "ABCD"]
    ).to(model.device)

    num_correct = 0
    num_seen = 0
    for batch in tqdm(batches):
        texts = [x[0] for x in batch]
        answers = t.tensor([x[1] for x in batch]).to(model.device)
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
        logits = model(**inputs).logits

        if "attention_mask" in inputs:
            # Locate the last attended position of every row. With left
            # padding this is always the final position; with right padding
            # it skips the trailing pad tokens, which position -1 would
            # otherwise (incorrectly) read from.
            mask = inputs["attention_mask"]
            trailing_pad = mask.flip(dims=[1]).argmax(dim=1)
            last_pos = (mask.shape[1] - 1 - trailing_pad).to(logits.device)
            row_idx = t.arange(logits.shape[0], device=logits.device)
            last_logits = logits[row_idx, last_pos]
        else:
            # No mask available: fall back to the final position.
            last_logits = logits[:, -1]

        predictions = last_logits[:, choice_idxs].argmax(dim=-1)
        matches = predictions == answers
        num_correct += int(matches.sum().item())
        num_seen += matches.numel()

    if num_seen == 0:
        raise ValueError("Cannot compute accuracy: no examples were provided")
    return num_correct / num_seen

if __name__ == '__main__':
    # Script entry point: load a model, stream a JSONL multiple-choice
    # dataset through it in batches, and print the resulting accuracy.
    parser = argparse.ArgumentParser(description='Evaluate a model on a multiple choice dataset')
    # Alternative defaults from earlier runs, kept for quick switching:
    # parser.add_argument('--model_name_or_path', type=str, default="HuggingFaceH4/zephyr-7b-beta")
    # parser.add_argument('--model_name_or_path', type=str, default="models/HuggingFaceH4/zephyr-7b-beta_alpha-[100.0, 100.0]_batches-80_layer-7_2024-06-13-17-37-29")
    # parser.add_argument('--model_name_or_path', type=str, default="models/HuggingFaceH4/ft-skip_split0--zephyr-7b-beta_alpha-[100.0, 100.0]_batches-80_layer-7_2024-06-13-17-37-29")
    parser.add_argument('--model_name_or_path', type=str, default="models/HuggingFaceH4/zephyr-7b-beta_alpha-[700.0]_batches-80_layer-7_2024-06-13-20-45-01", help='Model name or path passed to load_model')
    # parser.add_argument('--data_path', type=str, default='data/wmdp/wmdp-cyber.jsonl')
    # parser.add_argument('--data_path', type=str, default='data/mmlu/mmlu.jsonl')
    parser.add_argument('--data_path', type=str, default='data/mmlu/mmlu_3000.jsonl', help='Path to a JSONL file with one question per line')
    # parser.add_argument('--data_path', type=str, default='data/dates-years/split_0.jsonl')
    parser.add_argument('--batch_size', type=int, default=16, help='The batch size to use for evaluation')
    args = parser.parse_args()

    model, tokenizer = load_model(args.model_name_or_path)
    # Evaluation only: turn off autograd for everything that follows.
    t.set_grad_enabled(False)

    # The rows are parsed lazily, so the file must stay open until the whole
    # evaluation has consumed the generator; the `with` block closes it
    # afterwards (the original never closed the handle).
    with open(args.data_path, "r", encoding="utf-8") as data_file:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        reader = (json.loads(line) for line in data_file if line.strip())
        batches = prepare_data(reader, args.batch_size)
        accuracy = get_accuracy(model, tokenizer, batches)
    print(f"Accuracy for {args.model_name_or_path} on dataset {args.data_path}: {accuracy:.2f}")



